{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "G4vEEajUbvNc"
   },
   "source": [
    "# Ungraded Lab: Hyperparameter tuning and model training with TFX\n",
    "\n",
    "In this lab, you will be again doing hyperparameter tuning but this time, it will be within a [Tensorflow Extended (TFX)](https://www.tensorflow.org/tfx/) pipeline. \n",
    "\n",
    "We have already introduced some TFX components in Course 2 of this specialization related to data ingestion, validation, and transformation. In this notebook, you will get to work with two more which are related to model development and training: *Tuner* and *Trainer*.\n",
    "\n",
    "<img src='https://www.tensorflow.org/tfx/guide/images/prog_trainer.png' alt='tfx pipeline'>\n",
    "image source: https://www.tensorflow.org/tfx/guide\n",
    "\n",
    "* The *Tuner* utilizes the [Keras Tuner](https://keras-team.github.io/keras-tuner/) API under the hood to tune your model's hyperparameters.\n",
    "* You can get the best set of hyperparameters from the Tuner component and feed it into the *Trainer* component to optimize your model for training.\n",
    "\n",
    "You will again be working with the [FashionMNIST](https://github.com/zalandoresearch/fashion-mnist) dataset and will feed it though the TFX pipeline up to the Trainer component.You will quickly review the earlier components from Course 2, then focus on the two new components introduced.\n",
    "\n",
    "Let's begin!\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "MUXex9ctTuDB"
   },
   "source": [
    "## Setup"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "YEFWSi_-umNz"
   },
   "source": [
    "### Install TFX\n",
    "\n",
    "You will first install [TFX](https://www.tensorflow.org/tfx), a framework for developing end-to-end machine learning pipelines."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "IqR2PQG4ZaZ0"
   },
   "outputs": [],
   "source": [
    "!pip install -U pip\n",
    "!pip install tfx==1.12.0\n",
    "!pip install keras_tuner==1.2.0"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Yr2ulfeNvvom"
   },
   "source": [
    "*Note: In Google Colab, you need to restart the runtime at this point to finalize updating the packages you just installed. You can do so by clicking the `Restart Runtime` at the end of the output cell above (after installation), or by selecting `Runtime > Restart Runtime` in the Menu bar. **Please do not proceed to the next section without restarting.** You can also ignore the errors about version incompatibility of some of the bundled packages because we won't be using those in this notebook.*"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "T_MPhjWTvNSr"
   },
   "source": [
    "### Imports\n",
    "\n",
    "You will then import the packages you will need for this exercise."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "_leAIdFKAxAD"
   },
   "outputs": [],
   "source": [
    "import os\n",
    "import pprint\n",
    "\n",
    "import tensorflow as tf\n",
    "import tensorflow_datasets as tfds\n",
    "from tensorflow import keras\n",
    "from absl import logging\n",
    "\n",
    "from tfx import v1 as tfx\n",
    "from tfx.proto import example_gen_pb2, trainer_pb2\n",
    "from tfx.orchestration.experimental.interactive.interactive_context import InteractiveContext\n",
    "\n",
    "tf.get_logger().propagate = False\n",
    "tf.get_logger().setLevel('ERROR')\n",
    "pp = pprint.PrettyPrinter()\n",
    "logging.set_verbosity(logging.ERROR)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "ReV_UXOgCZvx"
   },
   "source": [
    "## Download and prepare the dataset\n",
    "\n",
    "As mentioned earlier, you will be using the Fashion MNIST dataset just like in the previous lab. This will allow you to compare the similarities and differences when using Keras Tuner as a standalone library and within an ML pipeline.\n",
    "\n",
    "You will first need to setup the directories that you will use to store the dataset, as well as the pipeline artifacts and metadata store."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "cNQlwf5_t8Fc"
   },
   "outputs": [],
   "source": [
    "# Location of the pipeline metadata store\n",
    "_pipeline_root = './pipeline/'\n",
    "\n",
    "# Directory of the raw data files\n",
    "_data_root = './data/fmnist'\n",
    "\n",
    "# Temporary directory\n",
    "tempdir = './tempdir'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "BqwtVwAsslgN"
   },
   "outputs": [],
   "source": [
    "# Create the dataset directory\n",
    "!mkdir -p {_data_root}\n",
    "\n",
    "# Create the TFX pipeline files directory\n",
    "!mkdir {_pipeline_root}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "JyjfgG0ax9uv"
   },
   "source": [
    "You will now download FashionMNIST from [Tensorflow Datasets](https://www.tensorflow.org/datasets). The `with_info` flag will be set to `True` so you can display information about the dataset in the next cell (i.e. using `ds_info`)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "aUzvq3WFvKyl"
   },
   "outputs": [],
   "source": [
    "# Download the dataset\n",
    "ds, ds_info = tfds.load('fashion_mnist', data_dir=tempdir, with_info=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "74BnhUcG1A-x"
   },
   "outputs": [],
   "source": [
    "# Display info about the dataset\n",
    "print(ds_info)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "PuByMtNV13JH"
   },
   "source": [
    "You can review the downloaded files with the code below. For this lab, you will be using the *train* TFRecord so you will need to take note of its filename. You will not use the *test* TFRecord in this lab."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "A501bxQd1Qxo"
   },
   "outputs": [],
   "source": [
    "# Define the location of the train tfrecord downloaded via TFDS\n",
    "tfds_data_path = f'{tempdir}/{ds_info.name}/{ds_info.version}'\n",
    "\n",
    "# Display contents of the TFDS data directory\n",
    "os.listdir(tfds_data_path)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Obla1v0RzXdB"
   },
   "source": [
    "You will then copy the train split from the downloaded data so it can be consumed by the ExampleGen component in the next step. This component requires that your files are in a directory without extra files (e.g. JSONs and TXT files)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "49ZklvN8d64e"
   },
   "outputs": [],
   "source": [
    "# Define the train tfrecord filename\n",
    "train_filename = 'fashion_mnist-train.tfrecord-00000-of-00001'\n",
    "\n",
    "# Copy the train tfrecord into the data root folder\n",
    "!cp {tfds_data_path}/{train_filename} {_data_root}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "MUy9vZq72ueR"
   },
   "source": [
    "## TFX Pipeline\n",
    "\n",
    "With the setup complete, you can now proceed to creating the pipeline. "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "X1gu2Bbi226z"
   },
   "source": [
    "### Initialize the Interactive Context\n",
    "\n",
    "You will start by initializing the [InteractiveContext](https://github.com/tensorflow/tfx/blob/master/tfx/orchestration/experimental/interactive/interactive_context.py) so you can run the components within this Colab environment. You can safely ignore the warning because you will just be using a local SQLite file for the metadata store."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "GeCZ5mAvvlD4"
   },
   "outputs": [],
   "source": [
    "# Initialize the InteractiveContext\n",
    "context = InteractiveContext(pipeline_root=_pipeline_root)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "WQwR6Cex3azZ"
   },
   "source": [
    "### ExampleGen\n",
    "\n",
    "You will start the pipeline by ingesting the TFRecord you set aside. The [ImportExampleGen](https://www.tensorflow.org/tfx/api_docs/python/tfx/components/ImportExampleGen) consumes TFRecords and you can specify splits as shown below. For this exercise, you will split the train tfrecord to use 80% for the train set, and the remaining 20% as eval/validation set."
   ]
  },
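  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The split is expressed through `hash_buckets`, which act as relative weights: each split receives its bucket count divided by the total number of buckets. As a rough illustration in plain Python (the `split_fractions` helper is hypothetical and only reproduces the ratio arithmetic, not how TFX hashes examples into buckets):\n",
    "\n",
    "```python\n",
    "def split_fractions(buckets):\n",
    "    # Hypothetical helper: turn bucket counts into split proportions\n",
    "    total = sum(buckets.values())\n",
    "    return {name: count / total for name, count in buckets.items()}\n",
    "\n",
    "fractions = split_fractions({'train': 8, 'eval': 2})\n",
    "print(fractions)  # {'train': 0.8, 'eval': 0.2}\n",
    "```\n",
    "\n",
    "Because examples are assigned to buckets by hashing, the actual split sizes will be close to, but not necessarily exactly, these proportions."
   ]
  },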
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "Xolw1d8lvqNW"
   },
   "outputs": [],
   "source": [
    "# Specify 80/20 split for the train and eval set\n",
    "output = example_gen_pb2.Output(\n",
    "    split_config=example_gen_pb2.SplitConfig(splits=[\n",
    "        example_gen_pb2.SplitConfig.Split(name='train', hash_buckets=8),\n",
    "        example_gen_pb2.SplitConfig.Split(name='eval', hash_buckets=2),\n",
    "    ]))\n",
    "\n",
    "# Ingest the data through ExampleGen\n",
    "example_gen = tfx.components.ImportExampleGen(input_base=_data_root, output_config=output)\n",
    "\n",
    "# Run the component\n",
    "context.run(example_gen)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "dIdWfRWGxvHp"
   },
   "outputs": [],
   "source": [
    "# Print split names and URI\n",
    "artifact = example_gen.outputs['examples'].get()[0]\n",
    "print(artifact.split_names, artifact.uri)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "os6NhLaY4oB3"
   },
   "source": [
    "### StatisticsGen\n",
    "\n",
    "Next, you will compute the statistics of the dataset with the [StatisticsGen](https://www.tensorflow.org/tfx/guide/statsgen) component."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "pVDS4oEIzZ83"
   },
   "outputs": [],
   "source": [
    "# Run StatisticsGen\n",
    "statistics_gen = tfx.components.StatisticsGen(\n",
    "    examples=example_gen.outputs['examples'])\n",
    "\n",
    "context.run(statistics_gen)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "D48bfGK95sES"
   },
   "source": [
    "### SchemaGen\n",
    "\n",
    "You can then infer the dataset schema with [SchemaGen](https://www.tensorflow.org/tfx/guide/schemagen). This will be used to validate incoming data to ensure that it is formatted correctly."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "7UhV3Jr7zp7p"
   },
   "outputs": [],
   "source": [
    "# Run SchemaGen\n",
    "schema_gen = tfx.components.SchemaGen(\n",
    "      statistics=statistics_gen.outputs['statistics'], infer_feature_shape=True)\n",
    "context.run(schema_gen)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "EtS2ZEgCzvAf"
   },
   "outputs": [],
   "source": [
    "# Visualize the results\n",
    "context.show(schema_gen.outputs['schema'])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "2_yXqq1y6LR6"
   },
   "source": [
    "### ExampleValidator\n",
    "\n",
    "You can assume that the dataset is clean since we downloaded it from TFDS. But just to review, let's run it through [ExampleValidator](https://www.tensorflow.org/tfx/guide/exampleval) to detect if there are anomalies within the dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "EaTJiYPpzzZM"
   },
   "outputs": [],
   "source": [
    "# Run ExampleValidator\n",
    "example_validator = tfx.components.ExampleValidator(\n",
    "    statistics=statistics_gen.outputs['statistics'],\n",
    "    schema=schema_gen.outputs['schema'])\n",
    "context.run(example_validator)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "v6YzedBSz5KE"
   },
   "outputs": [],
   "source": [
    "# Visualize the results. There should be no anomalies.\n",
    "context.show(example_validator.outputs['anomalies'])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "tpUFIO9M6yMH"
   },
   "source": [
    "### Transform\n",
    "\n",
    "Let's now use the [Transform](https://www.tensorflow.org/tfx/guide/transform) component to scale the image pixels and convert the data types to float. You will first define the transform module containing these operations before you run the component."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "xL4zrcJ7z9K9"
   },
   "outputs": [],
   "source": [
    "_transform_module_file = 'fmnist_transform.py'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "43xmp2UD0Cc5"
   },
   "outputs": [],
   "source": [
    "%%writefile {_transform_module_file}\n",
    "\n",
    "import tensorflow as tf\n",
    "import tensorflow_transform as tft\n",
    "\n",
    "# Keys\n",
    "_LABEL_KEY = 'label'\n",
    "_IMAGE_KEY = 'image'\n",
    "\n",
    "\n",
    "def _transformed_name(key):\n",
    "    return key + '_xf'\n",
    "\n",
    "def _image_parser(image_str):\n",
    "    '''converts the images to a float tensor'''\n",
    "    image = tf.image.decode_image(image_str, channels=1)\n",
    "    image = tf.reshape(image, (28, 28, 1))\n",
    "    image = tf.cast(image, tf.float32)\n",
    "    return image\n",
    "\n",
    "\n",
    "def _label_parser(label_id):\n",
    "    '''converts the labels to a float tensor'''\n",
    "    label = tf.cast(label_id, tf.float32)\n",
    "    return label\n",
    "\n",
    "\n",
    "def preprocessing_fn(inputs):\n",
    "    \"\"\"tf.transform's callback function for preprocessing inputs.\n",
    "    Args:\n",
    "        inputs: map from feature keys to raw not-yet-transformed features.\n",
    "    Returns:\n",
    "        Map from string feature key to transformed feature operations.\n",
    "    \"\"\"\n",
    "    \n",
    "    # Convert the raw image and labels to a float array\n",
    "    with tf.device(\"/cpu:0\"):\n",
    "        outputs = {\n",
    "            _transformed_name(_IMAGE_KEY):\n",
    "                tf.map_fn(\n",
    "                    _image_parser,\n",
    "                    tf.squeeze(inputs[_IMAGE_KEY], axis=1),\n",
    "                    dtype=tf.float32),\n",
    "            _transformed_name(_LABEL_KEY):\n",
    "                tf.map_fn(\n",
    "                    _label_parser,\n",
    "                    inputs[_LABEL_KEY],\n",
    "                    dtype=tf.float32)\n",
    "        }\n",
    "    \n",
    "    # scale the pixels from 0 to 1\n",
    "    outputs[_transformed_name(_IMAGE_KEY)] = tft.scale_to_0_1(outputs[_transformed_name(_IMAGE_KEY)])\n",
    "    \n",
    "    return outputs"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "0uNYsebhLC69"
   },
   "source": [
    "You will run the component by passing in the examples, schema, and transform module file.\n",
    "\n",
    "*Note: You can safely ignore the warnings and `udf_utils` related errors.*"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "qthHA2hO1JST"
   },
   "outputs": [],
   "source": [
    "# Setup the Transform component\n",
    "transform = tfx.components.Transform(\n",
    "    examples=example_gen.outputs['examples'],\n",
    "    schema=schema_gen.outputs['schema'],\n",
    "    module_file=os.path.abspath(_transform_module_file))\n",
    "\n",
    "# Run the component\n",
    "context.run(transform)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "QZkbL7sO8Y1N"
   },
   "source": [
    "### Tuner\n",
    "\n",
    "As the name suggests, the [Tuner](https://www.tensorflow.org/tfx/guide/tuner) component tunes the hyperparameters of your model. To use this, you will need to provide a *tuner module file* which contains a `tuner_fn()` function. In this function, you will mostly do the same steps as you did in the previous ungraded lab but with some key differences in handling the dataset. \n",
    "\n",
    "The Transform component earlier saved the transformed examples as TFRecords compressed in `.gz` format and you will need to load that into memory. Once loaded, you will need to create batches of features and labels so you can finally use it for hypertuning. This process is modularized in the `_input_fn()` below. \n",
    "\n",
    "Going back, the `tuner_fn()` function will return a `TunerFnResult` [namedtuple](https://docs.python.org/3/library/collections.html#collections.namedtuple) containing your `tuner` object and a set of arguments to pass to `tuner.search()` method. You will see these in action in the following cells. When reviewing the module file, we recommend viewing the `tuner_fn()` first before looking at the other auxiliary functions."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "aE1PLAs_6CVt"
   },
   "outputs": [],
   "source": [
    "# Declare name of module file\n",
    "_tuner_module_file = 'tuner.py'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "G0F-XhqVlUDB"
   },
   "outputs": [],
   "source": [
    "%%writefile {_tuner_module_file}\n",
    "\n",
    "# Define imports\n",
    "from kerastuner.engine import base_tuner\n",
    "import kerastuner as kt\n",
    "from tensorflow import keras\n",
    "from typing import NamedTuple, Dict, Text, Any, List\n",
    "from tfx.components.trainer.fn_args_utils import FnArgs, DataAccessor\n",
    "import tensorflow as tf\n",
    "import tensorflow_transform as tft\n",
    "\n",
    "# Declare namedtuple field names\n",
    "TunerFnResult = NamedTuple('TunerFnResult', [('tuner', base_tuner.BaseTuner),\n",
    "                                             ('fit_kwargs', Dict[Text, Any])])\n",
    "\n",
    "# Input key\n",
    "_IMAGE_KEY = 'image_xf'\n",
    "\n",
    "# Label key\n",
    "_LABEL_KEY = 'label_xf'\n",
    "\n",
    "# Callback for the search strategy\n",
    "stop_early = tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=5)\n",
    "\n",
    "\n",
    "def _gzip_reader_fn(filenames):\n",
    "  '''Load compressed dataset\n",
    "  \n",
    "  Args:\n",
    "    filenames - filenames of TFRecords to load\n",
    "\n",
    "  Returns:\n",
    "    TFRecordDataset loaded from the filenames\n",
    "  '''\n",
    "\n",
    "  # Load the dataset. Specify the compression type since it is saved as `.gz`\n",
    "  return tf.data.TFRecordDataset(filenames, compression_type='GZIP')\n",
    "  \n",
    "\n",
    "def _input_fn(file_pattern,\n",
    "              tf_transform_output,\n",
    "              num_epochs=None,\n",
    "              batch_size=32) -> tf.data.Dataset:\n",
    "  '''Create batches of features and labels from TF Records\n",
    "\n",
    "  Args:\n",
    "    file_pattern - List of files or patterns of file paths containing Example records.\n",
    "    tf_transform_output - transform output graph\n",
    "    num_epochs - Integer specifying the number of times to read through the dataset. \n",
    "            If None, cycles through the dataset forever.\n",
    "    batch_size - An int representing the number of records to combine in a single batch.\n",
    "\n",
    "  Returns:\n",
    "    A dataset of dict elements, (or a tuple of dict elements and label). \n",
    "    Each dict maps feature keys to Tensor or SparseTensor objects.\n",
    "  '''\n",
    "\n",
    "  # Get feature specification based on transform output\n",
    "  transformed_feature_spec = (\n",
    "      tf_transform_output.transformed_feature_spec().copy())\n",
    "  \n",
    "  # Create batches of features and labels\n",
    "  dataset = tf.data.experimental.make_batched_features_dataset(\n",
    "      file_pattern=file_pattern,\n",
    "      batch_size=batch_size,\n",
    "      features=transformed_feature_spec,\n",
    "      reader=_gzip_reader_fn,\n",
    "      num_epochs=num_epochs,\n",
    "      label_key=_LABEL_KEY)\n",
    "  \n",
    "  return dataset\n",
    "\n",
    "\n",
    "def model_builder(hp):\n",
    "  '''\n",
    "  Builds the model and sets up the hyperparameters to tune.\n",
    "\n",
    "  Args:\n",
    "    hp - Keras tuner object\n",
    "\n",
    "  Returns:\n",
    "    model with hyperparameters to tune\n",
    "  '''\n",
    "\n",
    "  # Initialize the Sequential API and start stacking the layers\n",
    "  model = keras.Sequential()\n",
    "  model.add(keras.layers.Input(shape=(28, 28, 1), name=_IMAGE_KEY))\n",
    "  model.add(keras.layers.Flatten())\n",
    "\n",
    "  # Tune the number of units in the first Dense layer\n",
    "  # Choose an optimal value between 32-512\n",
    "  hp_units = hp.Int('units', min_value=32, max_value=512, step=32)\n",
    "  model.add(keras.layers.Dense(units=hp_units, activation='relu', name='dense_1'))\n",
    "\n",
    "  # Add next layers\n",
    "  model.add(keras.layers.Dropout(0.2))\n",
    "  model.add(keras.layers.Dense(10, activation='softmax'))\n",
    "\n",
    "  # Tune the learning rate for the optimizer\n",
    "  # Choose an optimal value from 0.01, 0.001, or 0.0001\n",
    "  hp_learning_rate = hp.Choice('learning_rate', values=[1e-2, 1e-3, 1e-4])\n",
    "\n",
    "  model.compile(optimizer=keras.optimizers.Adam(learning_rate=hp_learning_rate),\n",
    "                loss=keras.losses.SparseCategoricalCrossentropy(),\n",
    "                metrics=['accuracy'])\n",
    "\n",
    "  return model\n",
    "\n",
    "def tuner_fn(fn_args: FnArgs) -> TunerFnResult:\n",
    "  \"\"\"Build the tuner using the KerasTuner API.\n",
    "  Args:\n",
    "    fn_args: Holds args as name/value pairs.\n",
    "\n",
    "      - working_dir: working dir for tuning.\n",
    "      - train_files: List of file paths containing training tf.Example data.\n",
    "      - eval_files: List of file paths containing eval tf.Example data.\n",
    "      - train_steps: number of train steps.\n",
    "      - eval_steps: number of eval steps.\n",
    "      - schema_path: optional schema of the input data.\n",
    "      - transform_graph_path: optional transform graph produced by TFT.\n",
    "  \n",
    "  Returns:\n",
    "    A namedtuple contains the following:\n",
    "      - tuner: A BaseTuner that will be used for tuning.\n",
    "      - fit_kwargs: Args to pass to tuner's run_trial function for fitting the\n",
    "                    model , e.g., the training and validation dataset. Required\n",
    "                    args depend on the above tuner's implementation.\n",
    "  \"\"\"\n",
    "\n",
    "  # Define tuner search strategy\n",
    "  tuner = kt.Hyperband(model_builder,\n",
    "                     objective='val_accuracy',\n",
    "                     max_epochs=10,\n",
    "                     factor=3,\n",
    "                     directory=fn_args.working_dir,\n",
    "                     project_name='kt_hyperband')\n",
    "\n",
    "  # Load transform output\n",
    "  tf_transform_output = tft.TFTransformOutput(fn_args.transform_graph_path)\n",
    "\n",
    "  # Use _input_fn() to extract input features and labels from the train and val set\n",
    "  train_set = _input_fn(fn_args.train_files[0], tf_transform_output)\n",
    "  val_set = _input_fn(fn_args.eval_files[0], tf_transform_output)\n",
    "\n",
    "\n",
    "  return TunerFnResult(\n",
    "      tuner=tuner,\n",
    "      fit_kwargs={ \n",
    "          \"callbacks\":[stop_early],\n",
    "          'x': train_set,\n",
    "          'validation_data': val_set,\n",
    "          'steps_per_epoch': fn_args.train_steps,\n",
    "          'validation_steps': fn_args.eval_steps\n",
    "      }\n",
    "  )"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "lzJbeuXNtI-7"
   },
   "source": [
    "With the module defined, you can now setup the Tuner component. You can see the description of each argument [here](https://www.tensorflow.org/tfx/api_docs/python/tfx/components/Tuner). \n",
    "\n",
    "Notice that we passed a `num_steps` argument to the train and eval args and this was used in the `steps_per_epoch` and `validation_steps` arguments in the tuner module above. This can be useful if you don't want to go through the entire dataset when tuning. For example, if you have 10GB of training data, it would be incredibly time consuming if you will iterate through it entirely just for one epoch and one set of hyperparameters. You can set the number of steps so your program will only go through a fraction of the dataset. \n",
    "\n",
    "You can compute for the total number of steps in one epoch by: `number of examples / batch size`. For this particular example, we have `48000 examples / 32 (default size)` which equals `1500` steps per epoch for the train set (compute val steps from 12000 examples). Since you passed `500` in the `num_steps` of the train args, this means that some examples will be skipped. This will likely result in lower accuracy readings but will save time in doing the hypertuning. Try modifying this value later and see if you arrive at the same set of hyperparameters."
   ]
  },
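  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The step arithmetic above can be sketched as follows. This is only an illustration: the example counts assume the 80/20 split of the 60,000 training images divides exactly, and `steps_per_epoch` is a hypothetical helper, not part of TFX or Keras.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "def steps_per_epoch(num_examples, batch_size=32):\n",
    "    # Hypothetical helper: batches needed to see every example once\n",
    "    return math.ceil(num_examples / batch_size)\n",
    "\n",
    "train_steps = steps_per_epoch(48000)  # 1500\n",
    "eval_steps = steps_per_epoch(12000)   # 375\n",
    "\n",
    "# Fraction of one full pass covered by num_steps=500 and num_steps=100\n",
    "train_fraction = 500 / train_steps\n",
    "eval_fraction = 100 / eval_steps\n",
    "```\n",
    "\n",
    "With these numbers, each tuning epoch covers about a third of the train split and a little more than a quarter of the eval split."
   ]
  },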
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "VqVSc6sS5A1m"
   },
   "outputs": [],
   "source": [
    "# Setup the Tuner component\n",
    "tuner = tfx.components.Tuner(\n",
    "    module_file=_tuner_module_file,\n",
    "    examples=transform.outputs['transformed_examples'],\n",
    "    transform_graph=transform.outputs['transform_graph'],\n",
    "    schema=schema_gen.outputs['schema'],\n",
    "    train_args=trainer_pb2.TrainArgs(splits=['train'], num_steps=500),\n",
    "    eval_args=trainer_pb2.EvalArgs(splits=['eval'], num_steps=100)\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "HdycQnAR7AvG"
   },
   "outputs": [],
   "source": [
    "# Run the component. This will take around 10 minutes to run.\n",
    "# When done, it will summarize the results and show the 10 best trials.\n",
    "context.run(tuner, enable_cache=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "uW50JS0d9Hd4"
   },
   "source": [
    "### Trainer\n",
    "\n",
    "Like the Tuner component, the [Trainer](https://www.tensorflow.org/tfx/guide/trainer) component also requires a module file to setup the training process. It will look for a `run_fn()` function that defines and trains the model. The steps will look similar to the tuner module file:\n",
    "\n",
    "* Define the model - You can get the results of the Tuner component through the `fn_args.hyperparameters` argument. You will see it passed into the `model_builder()` function below. If you didn't run `Tuner`, then you can just explicitly define the number of hidden units and learning rate.\n",
    "\n",
    "* Load the train and validation sets - You have done this in the Tuner component. For this module, you will pass in a `num_epochs` value (10) to indicate how many batches will be prepared. You can opt not to do this and pass a `num_steps` value as before.\n",
    "\n",
    "* Setup and train the model - This will look very familiar if you're already used to the [Keras Models Training API](https://keras.io/api/models/model_training_apis/). You can pass in callbacks like the [TensorBoard callback](https://www.tensorflow.org/api_docs/python/tf/keras/callbacks/TensorBoard) so you can visualize the results later.\n",
    "\n",
    "* Save the model - This is needed so you can analyze and serve your model. You will get to do this in later parts of the course and specialization."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "abSJjDM2ipKS"
   },
   "outputs": [],
   "source": [
    "# Declare trainer module file\n",
    "_trainer_module_file = 'trainer.py'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "QdgbwOFFihSg"
   },
   "outputs": [],
   "source": [
    "%%writefile {_trainer_module_file}\n",
    "\n",
    "from tensorflow import keras\n",
    "from typing import NamedTuple, Dict, Text, Any, List\n",
    "from tfx.components.trainer.fn_args_utils import FnArgs, DataAccessor\n",
    "import tensorflow as tf\n",
    "import tensorflow_transform as tft\n",
    "\n",
    "# Input key\n",
    "_IMAGE_KEY = 'image_xf'\n",
    "\n",
    "# Label key\n",
    "_LABEL_KEY = 'label_xf'\n",
    "\n",
    "def _gzip_reader_fn(filenames):\n",
    "  '''Load compressed dataset\n",
    "  \n",
    "  Args:\n",
    "    filenames - filenames of TFRecords to load\n",
    "\n",
    "  Returns:\n",
    "    TFRecordDataset loaded from the filenames\n",
    "  '''\n",
    "\n",
    "  # Load the dataset. Specify the compression type since it is saved as `.gz`\n",
    "  return tf.data.TFRecordDataset(filenames, compression_type='GZIP')\n",
    "  \n",
    "\n",
    "def _input_fn(file_pattern,\n",
    "              tf_transform_output,\n",
    "              num_epochs=None,\n",
    "              batch_size=32) -> tf.data.Dataset:\n",
    "  '''Create batches of features and labels from TF Records\n",
    "\n",
    "  Args:\n",
    "    file_pattern - List of files or patterns of file paths containing Example records.\n",
    "    tf_transform_output - transform output graph\n",
    "    num_epochs - Integer specifying the number of times to read through the dataset. \n",
    "            If None, cycles through the dataset forever.\n",
    "    batch_size - An int representing the number of records to combine in a single batch.\n",
    "\n",
    "  Returns:\n",
    "    A dataset of dict elements, (or a tuple of dict elements and label). \n",
    "    Each dict maps feature keys to Tensor or SparseTensor objects.\n",
    "  '''\n",
    "  transformed_feature_spec = (\n",
    "      tf_transform_output.transformed_feature_spec().copy())\n",
    "  \n",
    "  dataset = tf.data.experimental.make_batched_features_dataset(\n",
    "      file_pattern=file_pattern,\n",
    "      batch_size=batch_size,\n",
    "      features=transformed_feature_spec,\n",
    "      reader=_gzip_reader_fn,\n",
    "      num_epochs=num_epochs,\n",
    "      label_key=_LABEL_KEY)\n",
    "  \n",
    "  return dataset\n",
    "\n",
    "\n",
    "def model_builder(hp):\n",
    "  '''\n",
    "  Builds the model and sets up the hyperparameters to tune.\n",
    "\n",
    "  Args:\n",
    "    hp - Keras tuner object\n",
    "\n",
    "  Returns:\n",
    "    model with hyperparameters to tune\n",
    "  '''\n",
    "\n",
    "  # Initialize the Sequential API and start stacking the layers\n",
    "  model = keras.Sequential()\n",
    "  model.add(keras.layers.Input(shape=(28, 28, 1), name=_IMAGE_KEY))\n",
    "  model.add(keras.layers.Flatten())\n",
    "\n",
    "  # Get the number of units from the Tuner results\n",
    "  hp_units = hp.get('units')\n",
    "  model.add(keras.layers.Dense(units=hp_units, activation='relu'))\n",
    "\n",
    "  # Add next layers\n",
    "  model.add(keras.layers.Dropout(0.2))\n",
    "  model.add(keras.layers.Dense(10, activation='softmax'))\n",
    "\n",
    "  # Get the learning rate from the Tuner results\n",
    "  hp_learning_rate = hp.get('learning_rate')\n",
    "\n",
    "  # Setup model for training\n",
    "  model.compile(optimizer=keras.optimizers.Adam(learning_rate=hp_learning_rate),\n",
    "                loss=keras.losses.SparseCategoricalCrossentropy(),\n",
    "                metrics=['accuracy'])\n",
    "\n",
    "  # Print the model summary\n",
    "  model.summary()\n",
    "  \n",
    "  return model\n",
    "\n",
    "\n",
    "def run_fn(fn_args: FnArgs) -> None:\n",
    "  \"\"\"Defines and trains the model.\n",
    "  Args:\n",
    "    fn_args: Holds args as name/value pairs. Refer here for the complete attributes: \n",
    "    https://www.tensorflow.org/tfx/api_docs/python/tfx/components/trainer/fn_args_utils/FnArgs#attributes\n",
    "  \"\"\"\n",
    "\n",
    "  # Callback for TensorBoard\n",
    "  tensorboard_callback = tf.keras.callbacks.TensorBoard(\n",
    "      log_dir=fn_args.model_run_dir, update_freq='batch')\n",
    "  \n",
    "  # Load transform output\n",
    "  tf_transform_output = tft.TFTransformOutput(fn_args.transform_graph_path)\n",
    "  \n",
    "  # Create batches of data good for 10 epochs\n",
    "  train_set = _input_fn(fn_args.train_files[0], tf_transform_output, 10)\n",
    "  val_set = _input_fn(fn_args.eval_files[0], tf_transform_output, 10)\n",
    "\n",
    "  # Load best hyperparameters\n",
    "  hp = fn_args.hyperparameters.get('values')\n",
    "\n",
    "  # Build the model\n",
    "  model = model_builder(hp)\n",
    "\n",
    "  # Train the model\n",
    "  model.fit(\n",
    "      x=train_set,\n",
    "      validation_data=val_set,\n",
    "      callbacks=[tensorboard_callback]\n",
    "      )\n",
    "  \n",
    "  # Save the model\n",
    "  model.save(fn_args.serving_model_dir, save_format='tf')"
   ]
  },
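  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In `run_fn` above, the call `fn_args.hyperparameters.get('values')` suggests that the Trainer receives the tuned hyperparameters as a plain dictionary rather than as a Keras Tuner object. The sketch below mimics that lookup in pure Python. The `best_hparams_config` dictionary and the numbers in it are made up for illustration; they are not results from this lab.\n",
    "\n",
    "```python\n",
    "# Hypothetical stand-in for `fn_args.hyperparameters`\n",
    "# (assumed layout: a dict with a 'values' entry)\n",
    "best_hparams_config = {\n",
    "    'values': {'units': 64, 'learning_rate': 0.001},  # made-up numbers\n",
    "}\n",
    "\n",
    "# Same lookup that run_fn performs\n",
    "hp = best_hparams_config.get('values')\n",
    "\n",
    "# Same lookups that model_builder performs\n",
    "hp_units = hp.get('units')\n",
    "hp_learning_rate = hp.get('learning_rate')\n",
    "\n",
    "print(f'units: {hp_units}, learning rate: {hp_learning_rate}')\n",
    "```\n",
    "\n",
    "Because `hp` is an ordinary dictionary in this sketch, `hp.get('units')` returns `None` when the key is missing, so it is worth inspecting the Tuner output if a layer fails to build."
   ]
  },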
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Lu3fQwFX6E8Q"
   },
   "source": [
    "You can pass the output of the `Tuner` component to the `Trainer` by filling the `hyperparameters` argument with the `Tuner` output. This is indicated by the `tuner.outputs['best_hyperparameters']` below. You can see the definition of the other arguments [here](https://www.tensorflow.org/tfx/api_docs/python/tfx/components/Trainer)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "u0JOuqSKGsoQ"
   },
   "outputs": [],
   "source": [
    "# Setup the Trainer component\n",
    "trainer = tfx.components.Trainer(\n",
    "    module_file=_trainer_module_file,\n",
    "    examples=transform.outputs['transformed_examples'],\n",
    "    hyperparameters=tuner.outputs['best_hyperparameters'],\n",
    "    transform_graph=transform.outputs['transform_graph'],\n",
    "    schema=schema_gen.outputs['schema'],\n",
    "    train_args=trainer_pb2.TrainArgs(splits=['train']),\n",
    "    eval_args=trainer_pb2.EvalArgs(splits=['eval']))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "lQfTLKGf7BFk"
   },
   "source": [
    "Take note that when re-training your model, you don't always have to retune your hyperparameters. Once you have a set that you think performs well, you can just import it with the `Importer` component as shown in the [official docs](https://www.tensorflow.org/tfx/guide/tuner):\n",
    "\n",
    "```\n",
    "hparams_importer = Importer(\n",
    "    # This can be Tuner's output file or manually edited file. The file contains\n",
    "    # text format of hyperparameters (keras_tuner.HyperParameters.get_config())\n",
    "    source_uri='path/to/best_hyperparameters.txt',\n",
    "    artifact_type=HyperParameters,\n",
    ").with_id('import_hparams')\n",
    "\n",
    "trainer = Trainer(\n",
    "    ...\n",
    "    # An alternative is directly use the tuned hyperparameters in Trainer's user\n",
    "    # module code and set hyperparameters to None here.\n",
    "    hyperparameters = hparams_importer.outputs['result'])\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "IwM2743um1w3"
   },
   "outputs": [],
   "source": [
    "# Run the component\n",
    "context.run(trainer, enable_cache=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "PiuE7i0A8qEb"
   },
   "source": [
    "Your model should now be saved in your pipeline directory and you can navigate through it as shown below. The file is saved as `saved_model.pb`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "mQPZBkw_yl2i"
   },
   "outputs": [],
   "source": [
    "# Get artifact uri of trainer model output\n",
    "model_artifact_dir = trainer.outputs['model'].get()[0].uri\n",
    "\n",
    "# List subdirectories artifact uri\n",
    "print(f'contents of model artifact directory:{os.listdir(model_artifact_dir)}')\n",
    "\n",
    "# Define the model directory\n",
    "model_dir = os.path.join(model_artifact_dir, 'Format-Serving')\n",
    "\n",
    "# List contents of model directory\n",
    "print(f'contents of model directory: {os.listdir(model_dir)}')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "bu5Bsn0J9ol3"
   },
   "source": [
    "You can also visualize the training results by loading the logs saved by the Tensorboard callback."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "GPqoMMXv5NoY"
   },
   "outputs": [],
   "source": [
    "model_run_artifact_dir = trainer.outputs['model_run'].get()[0].uri\n",
    "\n",
    "%load_ext tensorboard\n",
    "%tensorboard --logdir {model_run_artifact_dir}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Q6H6eCKC9xLp"
   },
   "source": [
    "***Congratulations! You have now created an ML pipeline that includes hyperparameter tuning and model training. You will know more about the next components in future lessons but in the next section, you will first learn about a framework for automatically building ML pipelines: AutoML. Enjoy the rest of the course!***"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "private_outputs": true,
   "provenance": [],
   "toc_visible": true
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
